//	Draw4DGPUFunctions.metal
//
//	© 2025 by Jeff Weeks
//	See TermsOfUse.txt

#include <metal_stdlib>
using namespace metal;
#include "Draw4DGPUDefinitions.h"


//	On a device with a wide-color display, we interpret
//	all colors as linear Display P3.  Because Metal does
//	all its rendering in linear Extended-Range sRGB
//	(at least on iOS, and presumably on Apple Silicon Macs
//	as well), on a device with
//	a wide-color display we must convert from linear Display P3
//	to linear Extended-Range sRGB.
//
//	Design notes:
//
//		For "rainbow edges" it's essential that the color conversion
//		happen here in the GPU fragment function Draw4DFragmentFunctionHSV().
//
//		For "single-color edges" we could, if we wanted to squeeze out
//		every last bit of performance, let the CPU pre-compute
//		the required XR sRGB colors.  But that small performance boost
//		isn't needed, so I decided to do the color conversion
//		here in the GPU vertex function Draw4DVertexFunctionRGB(),
//		for simplicity and especially for consistency with how
//		we handle the rainbow edges' color conversion.
//
//	Function constant 0, supplied by the host when it specializes
//	these functions:
//		true  = treat all colors as linear Display P3 and convert them
//				to linear Extended-Range sRGB using gP3toXRsRGB;
//		false = treat all colors as linear sRGB and use them unchanged.
constant bool	gUseWideColor	[[ function_constant(0) ]];

//	Linear Display P3 → linear Extended-Range sRGB conversion matrix.
//
//	Metal matrices are column-major, and each inner triple below
//	initializes one COLUMN.  So the first triple is the XR sRGB image
//	of pure P3 red (1,0,0), the second of pure P3 green (0,1,0),
//	and the third of pure P3 blue (0,0,1).
//
//	Each row sums to 1 (to within rounding), so P3 white (1,1,1)
//	maps to sRGB white (1,1,1) and neutral grays pass through unchanged.
//	Saturated P3 colors may map to components outside [0,1]
//	(for example, P3 green has a negative red component);
//	extended-range sRGB is what makes that representable.
constant half3x3	gP3toXRsRGB =
					{
						//	Yes, I know, the following values contain
						//	way too many "significant" digits to fit
						//	into a half-precision float, but I'm leaving them
						//	there anyhow, for future reference.
						//	They're harmless.
						{ 1.2249401762805587, -0.0420569547096881, -0.0196375545903344},
						{-0.2249401762805597,  1.0420569547096874, -0.0786360455506319},
						{ 0.0000000000000001,  0.0000000000000000,  1.0982736001409661}
					};

//	The direction of maximal specular reflection is given by
//	the pre-computed vector
//
//		normalize(theLightDirection + theViewerDirection)
//
//	i.e. the "half vector" used for Blinn-Phong highlights.  With the
//	values below, gLightDirection + gViewerDirection = (-0.27, +0.27, -1.92),
//	which normalizes to approximately (-0.138, +0.138, -0.981).
//
//	These vectors are expressed in the same coordinates as normals after
//	itsModelViewMatrix has been applied, and they enter the lighting only
//	through dot() with those normals.  So they should be (approximately)
//	unit length; |gLightDirection| ≈ 0.996.
//
//	gViewerDirection is not needed at runtime, because it is already
//	folded into gSpecularDirection.  It is kept below, commented out,
//	for reference.
//
constant   half3 gLightDirection    = half3(-0.27, +0.27, -0.92);	//	points towards the light source
//constant half3 gViewerDirection   = half3( 0.00,  0.00, -1.00);	//	points towards the observer
constant   half3 gSpecularDirection = half3(-0.14, +0.14, -0.98);	//	direction of maximal specular reflection


//	RGB

//	Per-vertex attributes for single-color (RGB) content.
struct VertexInputRGB
{
	//	Model-space position (x, y, z).
	float3	pos	[[ attribute(VertexAttributePosition) ]];

	//	Model-space normal vector (nx, ny, nz).
	half3	nor	[[ attribute(VertexAttributeNormal) ]];
};

//	Data handed from Draw4DVertexFunctionRGB() to the rasterizer.
//	The color arrives already lit and fogged, and is simply interpolated.
struct VertexOutputRGBA
{
	//	Clip-space position.
	float4	position	[[ position ]];

	//	Color with premultiplied alpha (αR, αG, αB, α).
	//	Paired with FragmentInputRGBA via the user(color) tag.
	half4	color		[[ user(color) ]];
};

//	Interpolated data received by Draw4DFragmentFunctionRGB().
struct FragmentInputRGBA
{
	//	Color with premultiplied alpha (αR, αG, αB, α),
	//	matched to VertexOutputRGBA by the user(color) tag.
	half4	color	[[ user(color) ]];
};

//	Vertex function for single-color content.
//
//	Lighting (diffuse + specular), fog, and the optional
//	Display P3 → XR sRGB conversion are all evaluated here, once per vertex.
//	The finished premultiplied-alpha color is passed straight through
//	to Draw4DFragmentFunctionRGB().
//
vertex VertexOutputRGBA Draw4DVertexFunctionRGB(
	      VertexInputRGB				in				[[ stage_in							]],
	      constant Draw4DUniformData	&uniformData	[[ buffer(BufferIndexVFUniforms)	]],
	const device Draw4DInstanceDataRGB	*instanceData	[[ buffer(BufferIndexVFInstanceData)]],
	      ushort						iid				[[ instance_id						]]	)
{
	const device Draw4DInstanceDataRGB	&theInstance = instanceData[iid];
	VertexOutputRGBA					out;

	//	Model space → camera space → clip space.
	const float4	theCameraPosition	= theInstance.itsModelViewMatrix * float4(in.pos, 1.0);
	out.position						= uniformData.itsProjectionMatrix * theCameraPosition;

	//	itsModelViewMatrix may be applied directly to the normal (w = 0),
	//	because apart from compressing rainbow edges orthogonally to their
	//	normals it performs no dilation.
	const half3	theNormal	= ( half4x4(theInstance.itsModelViewMatrix) * half4(in.nor, 0.0h) ).xyz;

	//	Diffuse: at least 0.5 (half ambient, half Lambertian).
	const half	theDiffuse	= 0.50h + 0.50h * max(0.0h, dot(gLightDirection, theNormal) );

	//	Specular: 0.25 × (N·H)^16, with negative N·H clamped to 0.
	const half	theSpecular	= 0.25h * pow(max(0.0h, dot(gSpecularDirection, theNormal) ), 16.0h);

	//	Fog: 1.0h (no fog) at camera-space z ≤ 0, falling to 0.75h at z ≥ 1.
	const half	theFog		= 1.00h - 0.25h * clamp(half(theCameraPosition.z), 0.0h, 1.0h);

	//	itsRGB is linear Display P3 when gUseWideColor is set, linear sRGB otherwise.
	//	Either way, theLinearColor ends up in linear (Extended-Range) sRGB.
	const half3	theRawColor		= theInstance.itsRGB;
	const half3	theLinearColor	= gUseWideColor ? gP3toXRsRGB * theRawColor : theRawColor;

	//	Shade, add the (white) highlight, then fog.
	const half3	theLitColor		= theDiffuse * theLinearColor + theSpecular;
	const half3	theFoggedColor	= theFog * theLitColor;

	//	Premultiply by opacity: (αR, αG, αB, α).
	out.color = theInstance.itsOpacity * half4(theFoggedColor, 1.0h);

	return out;
}

//	Fragment function for single-color content.
//	All shading already happened per vertex, so the interpolated
//	premultiplied-alpha color is the final fragment color.
fragment half4 Draw4DFragmentFunctionRGB(
	FragmentInputRGBA	in	[[ stage_in ]])
{
	const half4	thePremultipliedColor = in.color;	//	(αR, αG, αB, α)

	return thePremultipliedColor;
}


//	HSV

//	Per-vertex attributes for rainbow (HSV) content.
struct VertexInputHSV
{
	//	Model-space position (x, y, z).
	float3	pos	[[ attribute(VertexAttributePosition) ]];

	//	Model-space normal vector (nx, ny, nz).
	half3	nor	[[ attribute(VertexAttributeNormal) ]];

	//	Weight ∈ {0.0, 1.0} choosing between the instance's two colors:
	//	1.0 selects itsHSV0, 0.0 selects itsHSV1.
	half	wgt	[[ attribute(VertexAttributeMisc) ]];
};

//	Data handed from Draw4DVertexFunctionHSV() to the rasterizer.
//	The color stays in HSV form, and the lighting/fog factors stay separate,
//	so that the HSV → RGB conversion and shading can happen per fragment.
struct VertexOutputHSV
{
	float4	position		[[ position			]];	//	clip-space position
	half3	hsvColor		[[ user(hsv)		]];	//	(hue, saturation, value)
	half	diffuseFactor	[[ user(diffuse)	]];
	half	specularFactor	[[ user(specular)	]];
	half	fogFactor		[[ user(fog)		]];	//	0.0h = fully fogged;  1.0h = no fog
	half	opacity			[[ user(opacity)	]];
};

//	Interpolated data received by Draw4DFragmentFunctionHSV().
//	Each member is matched to VertexOutputHSV by its user(...) tag.
struct FragmentInputHSV
{
	half3	hsvColor		[[ user(hsv)		]];	//	(hue, saturation, value)
	half	diffuseFactor	[[ user(diffuse)	]];
	half	specularFactor	[[ user(specular)	]];
	half	fogFactor		[[ user(fog)		]];	//	0.0h = fully fogged;  1.0h = no fog
	half	opacity			[[ user(opacity)	]];
};

//	Vertex function for rainbow (HSV) content.
//
//	Unlike the RGB path, no color conversion or shading is applied here.
//	The per-vertex HSV color and the separate diffuse, specular, fog and
//	opacity factors are emitted, and Draw4DFragmentFunctionHSV()
//	combines them per fragment.
//
vertex VertexOutputHSV Draw4DVertexFunctionHSV(
	      VertexInputHSV				in				[[ stage_in							]],
	      constant Draw4DUniformData	&uniformData	[[ buffer(BufferIndexVFUniforms)	]],
	const device Draw4DInstanceDataHSV	*instanceData	[[ buffer(BufferIndexVFInstanceData)]],
	      ushort						iid				[[ instance_id						]]	)
{
	const device Draw4DInstanceDataHSV	&theInstance = instanceData[iid];
	VertexOutputHSV						out;

	//	Model space → camera space → clip space.
	const float4	theCameraPosition	= theInstance.itsModelViewMatrix * float4(in.pos, 1.0);
	out.position						= uniformData.itsProjectionMatrix * theCameraPosition;

	//	in.wgt is 1.0 or 0.0, so this yields itsHSV0 or itsHSV1 respectively.
	out.hsvColor = in.wgt * theInstance.itsHSV0 + (1.0h - in.wgt) * theInstance.itsHSV1;

	//	itsModelViewMatrix may be applied directly to the normal (w = 0),
	//	because apart from compressing rainbow edges orthogonally to their
	//	normals it performs no dilation.
	const half3	theNormal	= ( half4x4(theInstance.itsModelViewMatrix) * half4(in.nor, 0.0h) ).xyz;

	out.diffuseFactor	= 0.50h + 0.50h * max(0.0h, dot(gLightDirection, theNormal) );
	out.specularFactor	= 0.25h * pow(max(0.0h, dot(gSpecularDirection, theNormal) ), 16.0h);
	out.fogFactor		= 1.00h - 0.25h * clamp(half(theCameraPosition.z), 0.0h, 1.0h);

	//	All content currently drawn with this function has itsOpacity = 1.0;
	//	the value is still passed along for future flexibility.
	out.opacity			= theInstance.itsOpacity;

	return out;
}

//	Fragment function for rainbow (HSV) content.
//
//	Converts the interpolated HSV color to RGB.  When gUseWideColor is set,
//	it treats that RGB as linear Display P3 and converts it to linear
//	Extended-Range sRGB.  It then applies the interpolated diffuse,
//	specular and fog factors, and returns the result with premultiplied
//	alpha (αR, αG, αB, α).
//
//	The HSV → RGB step follows
//
//		https://lolengine.net/blog/2013/07/27/rgb-to-hsv-in-glsl
//
//	Consider a single channel, say red, as a function of hue H ∈ [0,1]:
//
//		___            ___
//		   \          /
//		    \        /
//		     \______/
//
//	On [0,1] the V-shaped function |6H − 3| falls from 3 to 0 and rises
//	back to 3.  Subtracting 1 and clamping to [0,1] gives exactly the
//	trapezoid above.  Green and blue use the same shape with the hue
//	shifted by 2/3 and 1/3 respectively, and fract() wraps the shifted hue
//	back into [0,1).  Saturation then blends from white to that pure hue,
//	and value scales the result.
//
fragment half4 Draw4DFragmentFunctionHSV(
	FragmentInputHSV	in	[[ stage_in ]])
{
	const half3	theHSV = in.hsvColor;	//	(hue, saturation, value)

	//	Per channel: |6·fract(H + offset) − 3|, with offsets (1, 2/3, 1/3) for (R, G, B).
	const half3	theTriangle	= abs(fract(theHSV.xxx + half3(1.00000h, 0.66666h, 0.33333h)) * 6.0h - 3.0h);

	//	The fully saturated color for this hue.
	const half3	thePureHue	= clamp(theTriangle - 1.0h, 0.0h, 1.0h);

	//	White → pure hue by saturation, scaled by value.
	const half3	theRawColor	= theHSV.z * mix(half3(1.0h), thePureHue, theHSV.y);

	//	theRawColor is linear Display P3 when gUseWideColor is set, linear sRGB otherwise.
	//	Either way, theLinearColor ends up in linear (Extended-Range) sRGB.
	const half3	theLinearColor	= gUseWideColor ? gP3toXRsRGB * theRawColor : theRawColor;

	//	Shade, add the (white) highlight, then fog.
	const half3	theLitColor		= in.diffuseFactor * theLinearColor + in.specularFactor;
	const half3	theFoggedColor	= in.fogFactor * theLitColor;

	//	Premultiply by opacity: (αR, αG, αB, α).
	return in.opacity * half4(theFoggedColor, 1.0h);
}


//	Texture

//	Per-vertex attributes for textured content.
struct VertexInputTex
{
	//	Model-space position (x, y, z).
	float3	pos	[[ attribute(VertexAttributePosition) ]];

	//	Model-space normal vector (nx, ny, nz).
	half3	nor	[[ attribute(VertexAttributeNormal) ]];

	//	Texture coordinates (u, v).
	float2	tex	[[ attribute(VertexAttributeMisc) ]];
};

//	Data handed from Draw4DVertexFunctionTex() to the rasterizer.
struct VertexOutputTex
{
	//	Clip-space position.
	float4	position			[[ position ]];

	//	Texture coordinates (u, v), passed through unchanged.
	float2	texCoords			[[ user(texCoords) ]];

	//	Combined diffuse × fog factor, applied to the texel per fragment.
	half	brightnessFactor	[[ user(brightness) ]];
};

//	Interpolated data received by Draw4DFragmentFunctionTex().
//	Members are matched to VertexOutputTex by their user(...) tags.
struct FragmentInputTex
{
	//	Texture coordinates (u, v).
	float2	texCoords			[[ user(texCoords) ]];

	//	Combined diffuse × fog factor.
	half	brightnessFactor	[[ user(brightness) ]];
};

//	Vertex function for textured content.
//
//	Only diffuse lighting and fog are applied here; there is no specular
//	highlight.  The two are folded into a single brightness factor, which
//	Draw4DFragmentFunctionTex() multiplies into the sampled texel.
//
vertex VertexOutputTex Draw4DVertexFunctionTex(
	      VertexInputTex				in				[[ stage_in							]],
	      constant Draw4DUniformData	&uniformData	[[ buffer(BufferIndexVFUniforms)	]],
	const device Draw4DInstanceDataTex	*instanceData	[[ buffer(BufferIndexVFInstanceData)]],
	      ushort						iid				[[ instance_id						]]	)
{
	const device Draw4DInstanceDataTex	&theInstance = instanceData[iid];
	VertexOutputTex						out;

	//	Model space → camera space → clip space.
	const float4	theCameraPosition	= theInstance.itsModelViewMatrix * float4(in.pos, 1.0);
	out.position						= uniformData.itsProjectionMatrix * theCameraPosition;
	out.texCoords						= in.tex;

	//	itsModelViewMatrix may be applied directly to the normal (w = 0),
	//	because apart from compressing rainbow edges orthogonally to their
	//	normals it performs no dilation.
	const half3	theNormal	= ( half4x4(theInstance.itsModelViewMatrix) * half4(in.nor, 0.0h) ).xyz;

	//	Diffuse: at least 0.5 (half ambient, half Lambertian).
	const half	theDiffuse	= 0.50h + 0.50h * max(0.0h, dot(gLightDirection, theNormal) );

	//	Fog: 1.0h = no fog, falling to 0.75h at camera-space z ≥ 1.
	const half	theFog		= 1.00h - 0.25h * clamp( half(theCameraPosition.z), 0.0h, 1.0h);

	out.brightnessFactor = theDiffuse * theFog;

	return out;
}

//	Fragment function for textured content.
//
//	The texture is treated as grayscale.  Its red channel, scaled by the
//	interpolated brightness factor, supplies R, G and B alike, and the
//	fragment is fully opaque.
//
fragment half4 Draw4DFragmentFunctionTex(
	FragmentInputTex	in				[[ stage_in							]],
	texture2d<half>		texture			[[ texture(TextureIndexFFPrimary)	]],	//	contains brightness in red channel
	sampler				textureSampler	[[ sampler(SamplerIndexFFPrimary)	]])
{
	const half	theTexel		= texture.sample(textureSampler, in.texCoords).r;
	const half	theBrightness	= in.brightnessFactor * theTexel;

	return half4(half3(theBrightness), 1.0h);
}


#pragma mark -
#pragma mark Compute function

//	Compute kernel that fills a grid-line texture, one thread per pixel.
//
//	A pixel is on a grid line (white) when its x or y coordinate is
//	< someLimits[0] or ≥ someLimits[1]; otherwise it is mid-gray (0.5).
//	Alpha is always 1.
//
//	Threads whose position lies outside the texture return without writing.
//	So the kernel stays safe if the host dispatches whole threadgroups
//	whose combined size overshoots the texture's width or height.
//
kernel void Draw4DComputeFunctionMakeGridTexture(
	texture2d<half, access::write>	aTexture	[[ texture(TextureIndexCFDst)		]],
	constant ushort2				&someLimits	[[ buffer(BufferIndexCFGridLimits)	]],
	ushort2							aGridID		[[ thread_position_in_grid			]])
{
	half4	theColor;

	//	Skip any thread that falls outside the texture.
	if (aGridID.x >= aTexture.get_width()
	 || aGridID.y >= aTexture.get_height())
	{
		return;
	}

	theColor = (   (aGridID.x <  someLimits[0]	//	part of  left  border
				 || aGridID.x >= someLimits[1]	//	part of  right border
				 || aGridID.y <  someLimits[0]	//	part of bottom border
				 || aGridID.y >= someLimits[1])	//	part of  top   border
				?
				 half4(1.0, 1.0, 1.0, 1.0)	//	pixel on grid line
				:
				 half4(0.5, 0.5, 0.5, 1.0)	//	pixel not on grid line
			   );

	aTexture.write(theColor, aGridID);
}
